import torch
import numpy as np
from torch import nn


def Normalize(in_channels, num_groups=32):
	return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)

def nonlinearity(x): # swish
	return x*torch.sigmoid(x)


class AutoencoderKL(nn.Module):
	def __init__(self, config):
		super().__init__()
		self.embed_dim = config["embed_dim"]
		self.encoder = Encoder(**config)
		self.decoder = Decoder(**config)
		assert config["double_z"]
		self.quant_conv = torch.nn.Conv2d(2*config["z_channels"], 2*self.embed_dim, 1)
		self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, config["z_channels"], 1)

	def encode(self, x):
		h = self.encoder(x)
		moments = self.quant_conv(h)
		posterior = DiagonalGaussianDistribution(moments)
		return posterior

	def decode(self, z):
		z = self.post_quant_conv(z)
		dec = self.decoder(z)
		return dec

	def forward(self, input, sample_posterior=True):
		posterior = self.encode(input)
		if sample_posterior:
			z = posterior.sample()
		else:
			z = posterior.mode()
		dec = self.decode(z)
		return dec, posterior
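

# Hedged usage sketch (illustrative, not a training recipe): wires AutoencoderKL
# up with an assumed 256x256 RGB, f=8, 4-channel-latent config and runs one
# reconstruction round trip. All values below are assumptions.
def _example_autoencoder_roundtrip():
	config = {
		"embed_dim": 4, "z_channels": 4, "double_z": True,
		"ch": 128, "out_ch": 3, "ch_mult": (1, 2, 4, 4), "num_res_blocks": 2,
		"attn_resolutions": [], "in_channels": 3, "resolution": 256,
	}
	vae = AutoencoderKL(config)
	x = torch.randn(1, 3, 256, 256)
	dec, posterior = vae(x)   # dec: (1, 3, 256, 256)
	z = posterior.sample()    # latent sample: (1, 4, 32, 32)
	rec = vae.decode(z)       # decode an externally produced latent
	kl = posterior.kl()       # per-sample KL term for the VAE loss, shape (1,)
	return dec, rec, kl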


class Encoder(nn.Module):
	def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
				 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
				 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
				 **ignore_kwargs):
		super().__init__()
		if use_linear_attn: attn_type = "linear"
		self.ch = ch
		self.temb_ch = 0
		self.num_resolutions = len(ch_mult)
		self.num_res_blocks = num_res_blocks
		self.resolution = resolution
		self.in_channels = in_channels

		# downsampling
		self.conv_in = torch.nn.Conv2d(in_channels,
									   self.ch,
									   kernel_size=3,
									   stride=1,
									   padding=1)

		curr_res = resolution
		in_ch_mult = (1,)+tuple(ch_mult)
		self.in_ch_mult = in_ch_mult
		self.down = nn.ModuleList()
		for i_level in range(self.num_resolutions):
			block = nn.ModuleList()
			attn = nn.ModuleList()
			block_in = ch*in_ch_mult[i_level]
			block_out = ch*ch_mult[i_level]
			for i_block in range(self.num_res_blocks):
				block.append(ResnetBlock(in_channels=block_in,
										 out_channels=block_out,
										 temb_channels=self.temb_ch,
										 dropout=dropout))
				block_in = block_out
				if curr_res in attn_resolutions:
					attn.append(make_attn(block_in, attn_type=attn_type))
			down = nn.Module()
			down.block = block
			down.attn = attn
			if i_level != self.num_resolutions-1:
				down.downsample = Downsample(block_in, resamp_with_conv)
				curr_res = curr_res // 2
			self.down.append(down)

		# middle
		self.mid = nn.Module()
		self.mid.block_1 = ResnetBlock(in_channels=block_in,
									   out_channels=block_in,
									   temb_channels=self.temb_ch,
									   dropout=dropout)
		self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
		self.mid.block_2 = ResnetBlock(in_channels=block_in,
									   out_channels=block_in,
									   temb_channels=self.temb_ch,
									   dropout=dropout)

		# end
		self.norm_out = Normalize(block_in)
		self.conv_out = torch.nn.Conv2d(block_in,
										2*z_channels if double_z else z_channels,
										kernel_size=3,
										stride=1,
										padding=1)

	def forward(self, x):
		# timestep embedding
		temb = None

		# downsampling
		hs = [self.conv_in(x)]
		for i_level in range(self.num_resolutions):
			for i_block in range(self.num_res_blocks):
				h = self.down[i_level].block[i_block](hs[-1], temb)
				if len(self.down[i_level].attn) > 0:
					h = self.down[i_level].attn[i_block](h)
				hs.append(h)
			if i_level != self.num_resolutions-1:
				hs.append(self.down[i_level].downsample(hs[-1]))

		# middle
		h = hs[-1]
		h = self.mid.block_1(h, temb)
		h = self.mid.attn_1(h)
		h = self.mid.block_2(h, temb)

		# end
		h = self.norm_out(h)
		h = nonlinearity(h)
		h = self.conv_out(h)
		return h
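

# Hedged shape sketch: the Encoder halves the spatial resolution at every level
# but the last, so the output is resolution // 2**(len(ch_mult)-1) on each side
# and carries 2*z_channels channels when double_z is True. Numbers are illustrative.
def _example_encoder_shapes():
	enc = Encoder(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
		attn_resolutions=[], in_channels=3, resolution=64,
		z_channels=4, double_z=True)
	h = enc(torch.randn(1, 3, 64, 64))
	assert h.shape == (1, 8, 16, 16)  # 64 // 2**2 = 16; 2 * z_channels = 8
	return h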


class Decoder(nn.Module):
	def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
				 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
				 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
				 attn_type="vanilla", post_quant_conv=None, **ignorekwargs):
		super().__init__()
		if use_linear_attn: attn_type = "linear"
		self.ch = ch
		self.temb_ch = 0
		self.num_resolutions = len(ch_mult)
		self.num_res_blocks = num_res_blocks
		self.resolution = resolution
		self.in_channels = in_channels
		self.give_pre_end = give_pre_end
		self.tanh_out = tanh_out
		self.post_quant_conv = post_quant_conv

		# compute in_ch_mult, block_in and curr_res at lowest res
		in_ch_mult = (1,)+tuple(ch_mult)  # not used below; mirrors the Encoder
		block_in = ch*ch_mult[self.num_resolutions-1]
		curr_res = resolution // 2**(self.num_resolutions-1)
		self.z_shape = (1,z_channels,curr_res,curr_res)
		print("Working with z of shape {} = {} dimensions.".format(
			self.z_shape, np.prod(self.z_shape)))

		# z to block_in
		self.conv_in = torch.nn.Conv2d(z_channels,
									   block_in,
									   kernel_size=3,
									   stride=1,
									   padding=1)

		# middle
		self.mid = nn.Module()
		self.mid.block_1 = ResnetBlock(in_channels=block_in,
									   out_channels=block_in,
									   temb_channels=self.temb_ch,
									   dropout=dropout)
		self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
		self.mid.block_2 = ResnetBlock(in_channels=block_in,
									   out_channels=block_in,
									   temb_channels=self.temb_ch,
									   dropout=dropout)

		# upsampling
		self.up = nn.ModuleList()
		for i_level in reversed(range(self.num_resolutions)):
			block = nn.ModuleList()
			attn = nn.ModuleList()
			block_out = ch*ch_mult[i_level]
			for i_block in range(self.num_res_blocks+1):
				block.append(ResnetBlock(in_channels=block_in,
										 out_channels=block_out,
										 temb_channels=self.temb_ch,
										 dropout=dropout))
				block_in = block_out
				if curr_res in attn_resolutions:
					attn.append(make_attn(block_in, attn_type=attn_type))
			up = nn.Module()
			up.block = block
			up.attn = attn
			if i_level != 0:
				up.upsample = Upsample(block_in, resamp_with_conv)
				curr_res = curr_res * 2
			self.up.insert(0, up) # prepend to get consistent order

		# end
		self.norm_out = Normalize(block_in)
		self.conv_out = torch.nn.Conv2d(block_in,
										out_ch,
										kernel_size=3,
										stride=1,
										padding=1)

	def forward(self, z):
		#assert z.shape[1:] == self.z_shape[1:]
		self.last_z_shape = z.shape

		# timestep embedding
		temb = None

		# z to block_in
		h = self.conv_in(z)

		# middle
		h = self.mid.block_1(h, temb)
		h = self.mid.attn_1(h)
		h = self.mid.block_2(h, temb)

		# upsampling
		for i_level in reversed(range(self.num_resolutions)):
			for i_block in range(self.num_res_blocks+1):
				h = self.up[i_level].block[i_block](h, temb)
				if len(self.up[i_level].attn) > 0:
					h = self.up[i_level].attn[i_block](h)
			if i_level != 0:
				h = self.up[i_level].upsample(h)

		# end
		if self.give_pre_end:
			return h

		h = self.norm_out(h)
		h = nonlinearity(h)
		h = self.conv_out(h)
		if self.tanh_out:
			h = torch.tanh(h)
		return h


class DiagonalGaussianDistribution(object):
	def __init__(self, parameters, deterministic=False):
		self.parameters = parameters
		self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
		self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
		self.deterministic = deterministic
		self.std = torch.exp(0.5 * self.logvar)
		self.var = torch.exp(self.logvar)
		if self.deterministic:
			self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

	def sample(self):
		x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
		return x

	def kl(self, other=None):
		if self.deterministic:
			return torch.Tensor([0.])
		else:
			if other is None:
				return 0.5 * torch.sum(torch.pow(self.mean, 2)
									   + self.var - 1.0 - self.logvar,
									   dim=[1, 2, 3])
			else:
				return 0.5 * torch.sum(
					torch.pow(self.mean - other.mean, 2) / other.var
					+ self.var / other.var - 1.0 - self.logvar + other.logvar,
					dim=[1, 2, 3])

	def nll(self, sample, dims=[1,2,3]):
		if self.deterministic:
			return torch.Tensor([0.])
		logtwopi = np.log(2.0 * np.pi)
		return 0.5 * torch.sum(
			logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
			dim=dims)

	def mode(self):
		return self.mean
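

# Hedged sketch of the posterior helpers: `moments` stacks mean and log-variance
# along the channel axis, so a (1, 8, 4, 4) tensor parameterises a 4-channel
# diagonal Gaussian. Shapes below are illustrative.
def _example_posterior():
	moments = torch.randn(1, 8, 4, 4)   # [mean | logvar] concatenated on dim 1
	posterior = DiagonalGaussianDistribution(moments)
	z = posterior.sample()              # (1, 4, 4, 4) reparameterised draw
	kl = posterior.kl()                 # (1,) KL against a standard normal
	nll = posterior.nll(z)              # (1,) negative log-likelihood of z
	return z, kl, nll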


class Upsample(nn.Module):
	def __init__(self, in_channels, with_conv):
		super().__init__()
		self.with_conv = with_conv
		if self.with_conv:
			self.conv = torch.nn.Conv2d(in_channels,
										in_channels,
										kernel_size=3,
										stride=1,
										padding=1)

	def forward(self, x):
		# BF16 fix
		xh = x.to(torch.float32)
		xh = torch.nn.functional.interpolate(xh, scale_factor=2.0, mode="nearest")
		x = xh.to(x.dtype)

		if self.with_conv:
			x = self.conv(x)
		return x


class Downsample(nn.Module):
	def __init__(self, in_channels, with_conv):
		super().__init__()
		self.with_conv = with_conv
		if self.with_conv:
			# no asymmetric padding in torch conv, must do it ourselves
			self.conv = torch.nn.Conv2d(in_channels,
										in_channels,
										kernel_size=3,
										stride=2,
										padding=0)

	def forward(self, x):
		if self.with_conv:
			pad = (0,1,0,1)
			x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
			x = self.conv(x)
		else:
			x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
		return x
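

# Hedged shape sketch for the resampling blocks: Downsample pads the right and
# bottom edges by one before a stride-2 conv (so H, W go to floor(H/2), floor(W/2)),
# while Upsample doubles the spatial size with nearest-neighbour interpolation.
# Sizes below are illustrative.
def _example_resampling_shapes():
	x = torch.randn(1, 16, 32, 32)
	assert Downsample(16, with_conv=True)(x).shape == (1, 16, 16, 16)
	assert Upsample(16, with_conv=True)(x).shape == (1, 16, 64, 64)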
		

class ResnetBlock(nn.Module):
	def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
				 dropout, temb_channels=512):
		super().__init__()
		self.in_channels = in_channels
		out_channels = in_channels if out_channels is None else out_channels
		self.out_channels = out_channels
		self.use_conv_shortcut = conv_shortcut

		self.norm1 = Normalize(in_channels)
		self.conv1 = torch.nn.Conv2d(in_channels,
									 out_channels,
									 kernel_size=3,
									 stride=1,
									 padding=1)
		if temb_channels > 0:
			self.temb_proj = torch.nn.Linear(temb_channels,
											 out_channels)
		self.norm2 = Normalize(out_channels)
		self.dropout = torch.nn.Dropout(dropout)
		self.conv2 = torch.nn.Conv2d(out_channels,
									 out_channels,
									 kernel_size=3,
									 stride=1,
									 padding=1)
		if self.in_channels != self.out_channels:
			if self.use_conv_shortcut:
				self.conv_shortcut = torch.nn.Conv2d(in_channels,
													 out_channels,
													 kernel_size=3,
													 stride=1,
													 padding=1)
			else:
				self.nin_shortcut = torch.nn.Conv2d(in_channels,
													out_channels,
													kernel_size=1,
													stride=1,
													padding=0)

	def forward(self, x, temb):
		h = x
		h = self.norm1(h)
		h = nonlinearity(h)
		h = self.conv1(h)

		if temb is not None:
			h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]

		h = self.norm2(h)
		h = nonlinearity(h)
		h = self.dropout(h)
		h = self.conv2(h)

		if self.in_channels != self.out_channels:
			if self.use_conv_shortcut:
				x = self.conv_shortcut(x)
			else:
				x = self.nin_shortcut(x)

		return x+h
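

# Hedged sketch of the ResnetBlock skip path: when the channel count changes, the
# residual input is projected by a 3x3 conv (conv_shortcut=True) or a 1x1 "nin"
# conv before the addition. temb_channels is 0 throughout this file, so temb is
# always None and the timestep branch is skipped. Sizes are illustrative.
def _example_resnet_block():
	block = ResnetBlock(in_channels=32, out_channels=64, dropout=0.0, temb_channels=0)
	y = block(torch.randn(1, 32, 16, 16), temb=None)
	assert y.shape == (1, 64, 16, 16)
	return y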


class AttnBlock(nn.Module):
	def __init__(self, in_channels):
		super().__init__()
		self.in_channels = in_channels

		self.norm = Normalize(in_channels)
		self.q = torch.nn.Conv2d(in_channels,
								 in_channels,
								 kernel_size=1,
								 stride=1,
								 padding=0)
		self.k = torch.nn.Conv2d(in_channels,
								 in_channels,
								 kernel_size=1,
								 stride=1,
								 padding=0)
		self.v = torch.nn.Conv2d(in_channels,
								 in_channels,
								 kernel_size=1,
								 stride=1,
								 padding=0)
		self.proj_out = torch.nn.Conv2d(in_channels,
										in_channels,
										kernel_size=1,
										stride=1,
										padding=0)


	def forward(self, x):
		h_ = x
		h_ = self.norm(h_)
		q = self.q(h_)
		k = self.k(h_)
		v = self.v(h_)

		# compute attention
		b,c,h,w = q.shape
		q = q.reshape(b,c,h*w)
		q = q.permute(0,2,1)   # b,hw,c
		k = k.reshape(b,c,h*w) # b,c,hw
		w_ = torch.bmm(q,k)	 # b,hw,hw	w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
		w_ = w_ * (c**(-0.5))
		w_ = torch.nn.functional.softmax(w_, dim=2)

		# attend to values
		v = v.reshape(b,c,h*w)
		w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
		h_ = torch.bmm(v,w_)	 # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
		h_ = h_.reshape(b,c,h,w)

		h_ = self.proj_out(h_)

		return x+h_
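

# Hedged alternative (assumes torch >= 2.0): the attention above can be expressed
# with torch.nn.functional.scaled_dot_product_attention, whose default 1/sqrt(c)
# scaling and softmax over keys match the explicit bmm/softmax path. This helper
# reuses an existing AttnBlock's weights; it is a sketch, not a drop-in module.
def _attn_block_sdpa(attn, x):
	h_ = attn.norm(x)
	q, k, v = attn.q(h_), attn.k(h_), attn.v(h_)
	b, c, h, w = q.shape
	# flatten the spatial grid into a sequence of length h*w with embed size c
	q, k, v = (t.reshape(b, c, h * w).transpose(1, 2) for t in (q, k, v))
	h_ = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # (b, hw, c)
	h_ = h_.transpose(1, 2).reshape(b, c, h, w)
	return x + attn.proj_out(h_)
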

def make_attn(in_channels, attn_type="vanilla"):
	assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
	print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
	if attn_type == "vanilla":
		return AttnBlock(in_channels)
	elif attn_type == "none":
		return nn.Identity()
	else:
		# the "linear" path expects LinAttnBlock, which is not defined in this file
		return LinAttnBlock(in_channels)
